import jax
import jax.numpy as np
from tqdm import tqdm

def elbo(var_params, rng_key, iter, minibatch, log_p, log_q, sample_q):
	"""ELBO estimate whose ``log_q`` term ignores direct parameter gradients.

	``log_q`` is evaluated with ``stop_gradient``-ed parameters, so gradients
	with respect to ``var_params`` reach ``log_q`` only through the sample
	``z``. If ``sample_q`` is reparameterised, this corresponds to the
	"sticking the landing" style estimator.

	Args:
		var_params: variational parameters (any pytree accepted by the callables).
		rng_key: JAX PRNG key forwarded to ``sample_q``.
		iter: unused here; kept for a uniform objective signature.
		minibatch: ``None`` to evaluate once with ``chunk=None``, otherwise an
			array whose leading axis is vmapped over as ``chunk``.
		log_p: ``log_p(z=..., params=None, chunk=...)``.
		log_q: ``log_q(z=..., params=..., chunk=...)``.
		sample_q: ``sample_q(rng_key=..., params=..., chunk=...)``.

	Returns:
		Scalar ELBO estimate, averaged over chunks when ``minibatch`` is given.
	"""
	frozen = jax.lax.stop_gradient(var_params)

	def _per_chunk(chunk):
		sample = sample_q(rng_key=rng_key, params=var_params, chunk=chunk)
		log_joint = log_p(z=sample, params=None, chunk=chunk)
		log_density = log_q(z=sample, params=frozen, chunk=chunk)
		return log_joint - log_density

	if minibatch is not None:
		return np.mean(jax.vmap(_per_chunk)(minibatch))
	return _per_chunk(None)

def total_elbo(var_params, rng_key, iter, minibatch, log_p, log_q, sample_q):
	"""ELBO estimate with full gradients through ``log_q``.

	Same as ``elbo`` except that ``log_q`` receives the live ``var_params``.
	Gradients therefore flow both through the sample and through the
	density's direct dependence on the parameters.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key forwarded to ``sample_q``.
		iter: unused here; kept for a uniform objective signature.
		minibatch: ``None`` for a single evaluation with ``chunk=None``, or an
			array vmapped over its leading axis as ``chunk``.
		log_p, log_q, sample_q: model / variational callables (keyword-called).

	Returns:
		Scalar ELBO estimate, averaged over chunks when ``minibatch`` is given.
	"""
	def _per_chunk(chunk):
		sample = sample_q(rng_key=rng_key, params=var_params, chunk=chunk)
		log_joint = log_p(z=sample, params=None, chunk=chunk)
		log_density = log_q(z=sample, params=var_params, chunk=chunk)
		return log_joint - log_density

	if minibatch is not None:
		return np.mean(jax.vmap(_per_chunk)(minibatch))
	return _per_chunk(None)

def cfe_elbo(var_params, rng_key, iter, minibatch, log_p, sample_and_log_q):
	"""ELBO estimate where sampling and ``log_q`` come from one callable.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key forwarded to ``sample_and_log_q``.
		iter: unused here; kept for a uniform objective signature.
		minibatch: ``None`` for one evaluation with ``chunk=None``, or an array
			vmapped over its leading axis as ``chunk``.
		log_p: ``log_p(z=..., params=None, chunk=...)``.
		sample_and_log_q: returns ``(z, log_q)`` for the given key/params/chunk.

	Returns:
		Scalar ELBO estimate, averaged over chunks when ``minibatch`` is given.
	"""
	def _per_chunk(chunk):
		sample, log_density = sample_and_log_q(rng_key=rng_key, params=var_params, chunk=chunk)
		return log_p(z=sample, params=None, chunk=chunk) - log_density

	if minibatch is not None:
		return np.mean(jax.vmap(_per_chunk)(minibatch))
	return _per_chunk(None)


def eval_elbo(var_params, rng_key, num_copies_eval, log_p, sample_and_log_q, n_chunk, minibatch_use):
	"""Average ELBO over ``num_copies_eval`` independently keyed samples.

	Each copy draws ``(z, log_q)`` with its own key and scores
	``log_p(z) - log_q``.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key, split into one key per copy.
		num_copies_eval: number of independent samples.
		log_p: ``log_p(z=..., params=None, chunk=...)``.
		sample_and_log_q: returns ``(z, log_q)``.
		n_chunk: number of chunks visited when ``minibatch_use`` is true.
		minibatch_use: if true, per-copy values are averaged over chunks
			``0 .. n_chunk - 1`` (with a tqdm progress bar); otherwise the model
			is evaluated once with ``chunk=None``.

	Returns:
		Scalar mean over copies.
	"""
	def _score(chunk, key):
		draw, log_density = sample_and_log_q(rng_key=key, params=var_params, chunk=chunk)
		return log_p(z=draw, params=None, chunk=chunk) - log_density

	sample_keys = jax.random.split(rng_key, num_copies_eval)
	# The chunk index is shared (in_axes None); only the keys are mapped over.
	score_all = jax.jit(jax.vmap(_score, in_axes=(None, 0)))

	if not minibatch_use:
		return np.mean(score_all(None, sample_keys))

	per_copy = sum(
		(score_all(c, sample_keys) / n_chunk for c in tqdm(range(n_chunk))),
		np.zeros(num_copies_eval),
	)
	return np.mean(per_copy)

def eval_train_ll(var_params, rng_key, num_copies_eval, log_p, sample_and_log_q, n_chunk, minibatch_use):
	"""Log-mean-exp over ``num_copies_eval`` samples of ``log_p(z) - log_q``.

	The original author notes that this assumes the same ``log_p`` as
	``eval_test_ll``. NOTE(review): here it is called with
	``z``/``params``/``chunk`` keywords, so confirm the two are compatible.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key, split into one key per copy.
		num_copies_eval: number of independent samples.
		log_p: ``log_p(z=..., params=None, chunk=...)``.
		sample_and_log_q: returns ``(z, log_q)``.
		n_chunk: number of chunks visited when ``minibatch_use`` is true.
		minibatch_use: if true, per-copy values are averaged over chunks
			``0 .. n_chunk - 1``; otherwise evaluated once with ``chunk=None``.

	Returns:
		Scalar ``logsumexp(values) - log(num_copies_eval)``.
	"""
	def _score(chunk, key):
		draw, log_density = sample_and_log_q(rng_key=key, params=var_params, chunk=chunk)
		return log_p(z=draw, params=None, chunk=chunk) - log_density

	sample_keys = jax.random.split(rng_key, num_copies_eval)
	score_all = jax.jit(jax.vmap(_score, in_axes=(None, 0)))

	if minibatch_use:
		per_copy = np.zeros(num_copies_eval)
		for c in tqdm(range(n_chunk)):
			per_copy = per_copy + score_all(c, sample_keys) / n_chunk
	else:
		per_copy = score_all(None, sample_keys)

	log_norm = np.log(num_copies_eval)
	return jax.scipy.special.logsumexp(per_copy) - log_norm

def eval_test_ll(var_params, rng_key, num_copies_eval, log_p, sample_and_log_q, n_chunk, minibatch_use):
	"""Log-mean-exp over samples of the log-likelihood evaluated at ``z[1]``.

	Per the original author, ``log_p`` is assumed to return the summed local
	observation log-likelihood for one chunk, or for all observations when
	``chunk`` is None. The ``log_q`` output of ``sample_and_log_q`` is
	discarded.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key, split into one key per copy.
		num_copies_eval: number of independent samples.
		log_p: ``log_p(wi=..., chunk=...)``.
		sample_and_log_q: returns ``(z, log_q)``; only ``z[1]`` is used.
		n_chunk: number of chunks visited when ``minibatch_use`` is true.
		minibatch_use: if true, per-copy values are summed over chunks
			``0 .. n_chunk - 1``; otherwise evaluated once with ``chunk=None``.

	Returns:
		Scalar ``logsumexp(ll) - log(num_copies_eval)``.
	"""
	def _sample_ll(chunk, key):
		draw, _unused = sample_and_log_q(rng_key=key, params=var_params, chunk=chunk)
		return log_p(wi=draw[1], chunk=chunk)

	sample_keys = jax.random.split(rng_key, num_copies_eval)
	ll_all = jax.jit(jax.vmap(_sample_ll, in_axes=(None, 0)))

	if not minibatch_use:
		per_copy = ll_all(None, sample_keys)
	else:
		per_copy = sum(
			(ll_all(c, sample_keys) for c in tqdm(range(n_chunk))),
			np.zeros(num_copies_eval),
		)
	return jax.scipy.special.logsumexp(per_copy) - np.log(num_copies_eval)

def eval_mean_ll(var_params, rng_key, num_copies_eval, log_p, sample_and_log_q, n_chunk, minibatch_use):
	"""Mean per-observation log predictive density.

	For each observation, the predictive density is estimated from
	``num_copies_eval`` samples as
	``logsumexp_k ll_k - log(num_copies_eval)`` (log-mean-exp over samples).
	The result is the average of these values over observations.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key, split into one key per sample.
		num_copies_eval: number of posterior samples.
		log_p: ``log_p(wi=..., chunk=...)``. With ``minibatch_use`` it must
			return ``(idx, ll)``, where ``ll`` holds one log-likelihood per local
			observation and entries with ``idx < 0`` are ignored (presumably
			padding -- TODO confirm). Without minibatching, the code treats its
			return value as the per-observation vector ``ll`` itself.
		sample_and_log_q: returns ``(z, log_q)``; only ``z[1]`` is used.
		n_chunk: number of chunks visited when ``minibatch_use`` is true.
		minibatch_use: evaluate chunk by chunk (tqdm progress bar) instead of
			once with ``chunk=None``.

	Returns:
		Scalar mean log predictive over all valid observations.
	"""
	def _chunk_ll(chunk, rng_key):
		z, _ = sample_and_log_q(rng_key=rng_key, params=var_params, chunk=chunk)
		return log_p(wi=z[1], chunk=chunk)

	def _log_predictive(ll_per_sample):
		# Log-mean-exp over the sample axis, which vmap places at axis 0.
		return jax.scipy.special.logsumexp(ll_per_sample, 0) - np.log(num_copies_eval)

	keys = jax.random.split(rng_key, num_copies_eval)
	_chunk_ll = jax.jit(jax.vmap(_chunk_ll, in_axes=(None, 0)))

	if not minibatch_use:
		# Previously this branch never assigned the returned value and raised NameError.
		return np.mean(_log_predictive(_chunk_ll(None, keys)))

	# Accumulate the sum and count of valid observations. The result is then
	# the per-observation mean, consistent with the branch above, even when
	# chunks contain different numbers of valid entries.
	total_ll = 0.0
	total_count = 0
	for chunk in tqdm(range(n_chunk)):
		idx, ll = _chunk_ll(chunk, keys)
		# vmap gives idx a leading sample axis. The mask is presumably the same
		# for every sample, so row 0 is used.
		valid = idx[0] >= 0
		# where() rather than mask * value so that -inf/NaN padding cannot leak in.
		total_ll = total_ll + np.sum(np.where(valid, _log_predictive(ll), 0.0))
		total_count = total_count + np.sum(valid)
	return total_ll / total_count

def eval_mean_mean_ll(var_params, rng_key, num_copies_eval, log_p, sample_and_log_q, n_chunk, minibatch_use):
	"""Arithmetic mean over samples of the log-likelihood evaluated at ``z[1]``.

	The copies are combined with a plain mean of log values, not with a
	log-mean-exp.

	Args:
		var_params: variational parameters.
		rng_key: JAX PRNG key, split into one key per copy.
		num_copies_eval: number of independent samples.
		log_p: ``log_p(wi=..., chunk=...)``.
		sample_and_log_q: returns ``(z, log_q)``; only ``z[1]`` is used.
		n_chunk: number of chunks visited when ``minibatch_use`` is true.
		minibatch_use: if true, per-copy values are averaged over chunks
			``0 .. n_chunk - 1``; otherwise evaluated once with ``chunk=None``.

	Returns:
		Scalar mean of the per-copy values.
	"""
	def _sample_ll(chunk, key):
		draw, _unused = sample_and_log_q(rng_key=key, params=var_params, chunk=chunk)
		return log_p(wi=draw[1], chunk=chunk)

	sample_keys = jax.random.split(rng_key, num_copies_eval)
	ll_all = jax.jit(jax.vmap(_sample_ll, in_axes=(None, 0)))

	if not minibatch_use:
		return np.mean(ll_all(None, sample_keys))

	per_copy = np.zeros(num_copies_eval)
	for c in tqdm(range(n_chunk)):
		per_copy = per_copy + ll_all(c, sample_keys) / n_chunk
	return np.mean(per_copy)


def _multiple_obj_copies(var_params, rng_key, iter, minibatch, obj, num_copies, agg_func='mean'):
	"""Average ``obj`` over ``num_copies`` independent PRNG keys.

	Args:
		var_params: forwarded unchanged to ``obj``.
		rng_key: JAX PRNG key, split into ``num_copies`` keys.
		iter: forwarded unchanged to ``obj``.
		minibatch: forwarded unchanged to ``obj``.
		obj: ``obj(var_params, rng_key, iter, minibatch)``, vmapped over keys.
		num_copies: number of independent keys.
		agg_func: NOTE(review): accepted but never used; the copies are always
			combined with a plain mean.

	Returns:
		Scalar mean of ``obj`` over the copies.
	"""
	def _one_copy(key):
		return obj(var_params, key, iter, minibatch)

	copy_keys = jax.random.split(rng_key, num_copies)
	return np.mean(jax.vmap(_one_copy)(copy_keys))
